package authservice

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"gitee.com/go-mid/auth/internal/component"
	"gitee.com/go-mid/auth/internal/entity"
	"gitee.com/go-mid/booter/web"
	"gitee.com/go-mid/infra/xcache/xredis"
	"gitee.com/go-mid/infra/xcontext"
	"gitee.com/go-mid/infra/xlog"
	"gitee.com/go-mid/infra/xsql/xdb"
	"gitee.com/go-mid/token/tokenservice"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"time"
)

type (
	// AccessToken pairs a freshly issued access token (Atk) with the refresh
	// token (Rtk) issued alongside it. Both fields are tagged `json:"-"`, so
	// the struct serializes to an empty JSON object.
	AccessToken struct {
		Atk *TokenInfo `json:"-"`
		Rtk *TokenInfo `json:"-"`
	}

	// RefreshTokenPayload is the JSON payload embedded in a refresh token. It
	// carries the user, the user pool and the access-token string that was
	// issued together with the refresh token.
	RefreshTokenPayload struct {
		UserId      int64  `json:"userId"`
		UserPoolID  int64  `json:"user_pool_id"`
		AccessToken string `json:"access_token"`
	}

	// TokenInfo is a token string plus the expiry reported by the token
	// service (presumably a Unix timestamp in seconds — TODO confirm).
	TokenInfo struct {
		Expire int64  `json:"expire"`
		Token  string `json:"token"`
	}
)

// getAtkOpenKey returns the issuer key the token service assigned to this
// service for access tokens. The key is a fixed value; ctx is unused.
func getAtkOpenKey(ctx context.Context) string {
	const atkIssuer = "xingyuAtk"
	return atkIssuer
}

// getRtkOpenKey returns the issuer key the token service assigned to this
// service for refresh tokens. The key is a fixed value; ctx is unused.
func getRtkOpenKey(ctx context.Context) string {
	const rtkIssuer = "xingyuRtk"
	return rtkIssuer
}

// getMaxLenClient returns the maximum length, in bytes, of the serialized
// client info that is kept with a token. Longer values are discarded by the
// caller rather than truncated. ctx is unused.
func getMaxLenClient(ctx context.Context) int {
	const maxClientInfoLen = 512
	return maxClientInfoLen
}
// buildToken JSON-encodes payload and asks the token service to issue a token
// for the issuer identified by key.
//
// It returns the issued token and the expiry reported by the token service.
// Errors:
//   - payload cannot be marshaled: a gRPC status error with
//     codes.PermissionDenied. NOTE(review): Internal would be more apt for a
//     serialization failure; PermissionDenied is kept for caller compatibility.
//   - the token service call fails: that error is returned unchanged.
func (m *AuthServiceImpl) buildToken(ctx context.Context, key string, payload interface{}) (*TokenInfo, error) {
	fun := "AuthServiceImpl.buildToken -->"
	payloadData, err := json.Marshal(payload)
	if err != nil {
		xlog.Warnf(ctx, "%s marsh payload err: key: %s, payload: %v, err: %v", fun, key, payload, err)
		return nil, status.Error(codes.PermissionDenied, err.Error())
	}
	res, err := tokenservice.DefaultTokenService.GetToken(ctx, &tokenservice.GetTokenReq{
		Issuer:  key,
		Payload: string(payloadData),
	})
	if err != nil {
		// Log the error itself: the response carries no useful information
		// when the call failed.
		xlog.Errorf(ctx, "%s create token err: key: %s, payload: %v, err: %v", fun, key, payload, err)
		return nil, err
	}
	return &TokenInfo{
		Expire: res.Expire,
		Token:  res.Token,
	}, nil
}
// buildAccessToken issues an access token (atk) for req and then a refresh
// token (rtk) whose payload embeds that access-token string.
//
// The atk payload is a web.XFRMAuthInfo built from req's user, auth type, app
// and user pool. The rtk payload is a RefreshTokenPayload with the user, the
// user pool and the issued atk string. Errors from buildToken are returned
// unchanged. A nil token without an error yields codes.InvalidArgument.
func (m *AuthServiceImpl) buildAccessToken(ctx context.Context, req *GetTokenReq) (*AccessToken, error) {
	const fun = "buildAccessToken -->"

	atk, err := m.buildToken(ctx, getAtkOpenKey(ctx), &web.XFRMAuthInfo{
		UserId:     req.UserId,
		AuthType:   req.AuthType,
		Appid:      req.Appid,
		UserPoolID: req.UserPoolID,
	})
	switch {
	case err != nil:
		xlog.Errorf(ctx, "%s create atk token fail: req: %v, err: %v", fun, req, err)
		return nil, err
	case atk == nil:
		xlog.Errorf(ctx, "%s create atk token fail: req: %v", fun, req)
		return nil, status.Error(codes.InvalidArgument, "atk invalid")
	}

	// The refresh token is issued second because its payload references the
	// access-token string produced above.
	rtk, err := m.buildToken(ctx, getRtkOpenKey(ctx), &RefreshTokenPayload{
		UserId:      req.UserId,
		AccessToken: atk.Token,
		UserPoolID:  req.UserPoolID,
	})
	switch {
	case err != nil:
		xlog.Errorf(ctx, "%s create rtk token fail: req: %v, err: %v", fun, req, err)
		return nil, err
	case rtk == nil:
		xlog.Errorf(ctx, "%s create rtk token fail: req: %v", fun, req)
		return nil, status.Error(codes.InvalidArgument, "rtk invalid")
	}

	return &AccessToken{Atk: atk, Rtk: rtk}, nil
}
// verifyAccessToken asks the token service to verify token tk issued under the
// issuer key and decodes the token's JSON payload into payload.
//
// Parameters:
//   - key: the issuer key the token was created with.
//   - tk: the token string to verify.
//   - payload: pointer that receives the decoded payload. A nil payload skips
//     decoding and only reports validity.
//
// It returns whether the token service reports the token as valid. An error
// is returned when key or tk is empty, or when the token service call fails.
// An error is also returned when a token reported valid carries a payload
// that cannot be decoded, so callers never act on a "valid" token with an
// unpopulated payload. For a token reported invalid, a decode failure is
// only logged.
func (m *AuthServiceImpl) verifyAccessToken(ctx context.Context, key, tk string, payload interface{}) (bool, error) {
	fun := "AuthServiceImpl.verifyAccessToken -->"
	if key == "" || tk == "" {
		xlog.Warnf(ctx, "%s  token empty: key: %s, token: %v", fun, key, tk)
		return false, errors.New("empty token")
	}
	res, err := tokenservice.DefaultTokenService.VerifyToken(ctx, &tokenservice.VerifyTokenReq{
		Issuer: key,
		Token:  tk,
	})
	if err != nil {
		// Log the error itself; the response is not meaningful when err != nil.
		xlog.Errorf(ctx, "%s verify token err: key: %s, token: %v, err: %v", fun, key, tk, err)
		return false, err
	}
	if payload == nil {
		// Caller only wants the validity verdict.
		return res.IsValid, nil
	}
	if err := json.Unmarshal([]byte(res.Payload), payload); err != nil {
		xlog.Warnf(ctx, "%s umarshal payload error: res: %v, err: %v", fun, res, err)
		if res.IsValid {
			// A valid token whose identity payload cannot be read must not be
			// reported as valid: callers would see a zero-valued payload.
			return false, fmt.Errorf("decode token payload: %w", err)
		}
	}
	return res.IsValid, nil
}
// atkCacheKey returns the cache key for the access token: the token string
// prefixed with "atk_". ctx is unused.
func atkCacheKey(ctx context.Context, token string) string {
	const prefix = "atk_"
	return prefix + token
}

// getTokenForceRefresh issues a brand-new access/refresh token pair for req.
// It does not look up or reuse any existing token.
//
// Flow:
//  1. Build the token pair with buildAccessToken.
//  2. Persist it as an entity.AuthToken row. This is best effort: if the
//     insert fails, the error is logged and the issued tokens are still
//     returned.
//  3. Hand the row to setTokenCache.
//  4. In a background goroutine, let processLimitToken delete the user's
//     older tokens beyond the configured count limit.
//
// Serialized client info longer than getMaxLenClient bytes is dropped (stored
// as ""), not truncated.
func (m *AuthServiceImpl) getTokenForceRefresh(ctx context.Context, req *GetTokenReq) (*GetTokenRes, error) {
	fun := "AuthServiceImpl.getTokenForceRefresh -->"
	// Build the new token pair.
	accessToken, err := m.buildAccessToken(ctx, req)
	if err != nil {
		xlog.Warnf(ctx, "%s build atk err: req: %v, err: %v", fun, req, err)
		return nil, err
	}
	if accessToken == nil || accessToken.Atk == nil || accessToken.Rtk == nil {
		xlog.Warnf(ctx, "%s build atk fail: req: %v ", fun, req)
		return nil, status.Error(codes.Internal, "create token fail")
	}
	now := time.Now().Unix()
	clientStr := req.ClientInfo.ToHeaderValue()
	if maxClient := getMaxLenClient(ctx); len(clientStr) > maxClient {
		xlog.Warnf(ctx, "%s client too long clear it: req: %v", fun, req)
		clientStr = ""
	}
	atk := accessToken.Atk
	rtk := accessToken.Rtk
	item := &entity.AuthToken{
		Ct:                 now,
		Ut:                 now,
		UserID:             req.UserId,
		Appid:              int64(req.Appid),
		AuthType:           int32(req.AuthType),
		AccessToken:        atk.Token,
		AccessTokenExpire:  atk.Expire,
		RefreshToken:       rtk.Token,
		RefreshTokenExpire: rtk.Expire,
		ClientInfo:         clientStr,
		UserPoolID:         req.UserPoolID,
	}
	ups, err := xdb.BuildDbSqlMap(ctx, item, true)
	if err != nil {
		xlog.Warnf(ctx, "%s BuildDbSqlMap err: req: %v, err: %v ", fun, req, err)
		return nil, status.Error(codes.Internal, "build insert sql error")
	}
	// Persist the token to the DB.
	if _, err := component.XDBAuth.Insert(ctx, entity.AuthTokenTable, []map[string]interface{}{ups}); err != nil {
		xlog.Errorf(ctx, "%s add token db error: req: %v, err: %v", fun, req, err)
		// Degrade instead of failing: the tokens were already issued by the
		// token service, so they are still returned to the caller.
	}
	// Cache the token.
	m.setTokenCache(ctx, item)

	// Enforce the login-count limit in the background: existing tokens are
	// trimmed according to the per-user limit (e.g. one token per user).
	go func() {
		m.processLimitToken(xcontext.NewValueContext(ctx), &processLimitTokenReq{
			UserId:     req.UserId,
			UserPoolID: req.UserPoolID,
			// processLimitToken filters rows on "appid" and picks the limit by
			// appid, so it must carry the same app id stored in item.Appid.
			// Left at zero, the query matched only appid=0 rows and the
			// limit was never applied to this app's tokens.
			Appid:    int(req.Appid),
			AuthType: req.AuthType,
			Did:      req.ClientInfo.DeviceId,
		})
	}()
	xlog.Infof(ctx, "%s get token req: %+v, atk: %+v", fun, req, item)
	return &GetTokenRes{
		AccessToken:        atk.Token,
		AccessTokenExpire:  atk.Expire,
		RefreshToken:       rtk.Token,
		RefreshTokenExpire: rtk.Expire,
	}, nil
}

// setTokenCache is meant to store authToken in the token cache under
// atkCacheKey(authToken.AccessToken). The TTL is the access token's
// remaining lifetime plus 10 seconds, with AccessTokenExpire compared
// against the current Unix time in seconds.
//
// The cache backend is not wired up yet (see the TODO below), so the
// serialized token is currently only logged. Nil tokens and already-expired
// tokens are skipped. A JSON marshaling failure is logged and the write is
// skipped.
func (m *AuthServiceImpl) setTokenCache(ctx context.Context, authToken *entity.AuthToken) {
	fun := "AuthServiceImpl.setTokenCache -->"
	if authToken == nil {
		return
	}
	now := time.Now().Unix()
	expiredSec := authToken.AccessTokenExpire - now
	if expiredSec <= 0 {
		xlog.Warnf(ctx, "%s authtoken expired not set cache: authToken: %v, now: %d", fun, authToken, now)
		return
	}
	// Keep the cache entry 10 seconds longer than the token itself.
	expiredSecDuration := time.Duration(expiredSec+10) * time.Second
	atkKey := atkCacheKey(ctx, authToken.AccessToken)
	atkVal, err := json.Marshal(authToken)
	if err != nil {
		xlog.Warnf(ctx, "%s marsh atk  error: authToken: %+v, err: %v", fun, authToken, err)
		return
	}
	// TODO: write to the cache. The TTL is printed with %v so it reads as a
	// duration (e.g. "1h0m10s"); %d printed raw nanoseconds.
	xlog.Infof(ctx, "%s write to cache atk: atkKey: %s, atkVal: %s, expiredSecDuration: %v", fun, atkKey, atkVal, expiredSecDuration)
	//_, err := RedisClient.Set(ctx, atkKey, string(atkVal), expiredSecDuration)
	//if err != nil {
	//	xlog.Warnf(ctx, "%s set token cache error: authToken: %v, err: %v", fun, authToken, err)
	//}
}
// getTokenCache looks up the cached record for access token tk.
//
// No cache backend exists yet: it logs the cache key and always reports a
// miss by returning an error whose message is xredis.RedisNil.
func (m *AuthServiceImpl) getTokenCache(ctx context.Context, tk string) (*entity.AuthToken, error) {
	const fun = "AuthServiceImpl.getTokenCache -->"
	xlog.Infof(ctx, "%s get cache: atkkey: %s, ", fun, atkCacheKey(ctx, tk))
	// TODO: there is no cache yet.
	miss := errors.New(xredis.RedisNil)
	return nil, miss
}
// delTokenCache evicts the cached record for access token tk.
//
// No cache backend exists yet: it only logs the cache key and always returns nil.
func (m *AuthServiceImpl) delTokenCache(ctx context.Context, tk string) error {
	const fun = "AuthServiceImpl.delTokenCache -->"
	xlog.Infof(ctx, "%s del cache: atkkey: %s, ", fun, atkCacheKey(ctx, tk))
	// TODO: there is no cache yet.
	return nil
}
// delAuthToken deletes the entity.AuthTokenTable row whose access_token
// equals tk, then evicts the matching cache entry.
//
// The DB delete determines the result, and its error is returned. Cache
// eviction is best effort: a failure there is logged but not returned, since
// the DB row is already gone.
func (m *AuthServiceImpl) delAuthToken(ctx context.Context, tk string) error {
	fun := "AuthServiceImpl.delAuthToken -->"
	// Delete the DB record first...
	if _, err := component.XDBAuth.Delete(ctx, entity.AuthTokenTable, map[string]interface{}{
		"access_token": tk,
	}); err != nil {
		xlog.Warnf(ctx, "%s delete authtoken error: tk: %s, err: %v", fun, tk, err)
		return err
	}
	// ...then the cache record. Previously this error was silently dropped.
	if err := m.delTokenCache(ctx, tk); err != nil {
		xlog.Warnf(ctx, "%s delete token cache error: tk: %s, err: %v", fun, tk, err)
	}
	return nil
}

type (
	// processLimitTokenReq identifies the user/app/auth-type scope in which
	// processLimitToken enforces token-count limits.
	processLimitTokenReq struct {
		UserId     int64
		UserPoolID int64
		Appid      int
		AuthType   web.EnumAuthType
		Did        string // device id; not currently consulted by processLimitToken
	}

	// whereCount pairs a DB query filter with the number of matching tokens
	// (newest first) that may be kept.
	whereCount struct {
		where map[string]interface{}
		count int
	}
)

// processLimitToken trims a user's existing tokens down to the configured
// count limit.
//
// Only the per-user dimension (appid, auth type, user id, user pool) is
// enforced today. It lists up to 100 of the user's tokens, newest first
// ("ct desc"). It keeps the first `limit` of them and deletes the rest via
// delAuthToken, which deletes the DB row first and then the cache entry.
// Listing or deletion errors are logged and skipped; nothing is returned.
func (m *AuthServiceImpl) processLimitToken(ctx context.Context, req *processLimitTokenReq) {
	fun := "AuthServiceImpl.processLimitToken -->"

	// Collect every dimension that has a positive configured limit.
	wheres := make([]*whereCount, 0, 1)
	if limit, ok := getUserTokenCounts(ctx, req.Appid, req.AuthType); ok && limit > 0 {
		wheres = append(wheres, &whereCount{
			where: map[string]interface{}{
				"appid":        req.Appid,
				"auth_type":    req.AuthType,
				"user_id":      req.UserId,
				"user_pool_id": req.UserPoolID,
				"_orderby":     "ct desc",
				"_limit":       []int{0, 100},
			},
			count: limit,
		})
	}

	for _, wc := range wheres {
		var tokens []*entity.AuthToken
		if err := component.XDBAuth.Select(ctx, entity.AuthTokenTable, wc.where, &tokens); err != nil {
			xlog.Warnf(ctx, "%s list authtoken err: where: %v, err: %v", fun, wc, err)
			continue
		}
		if len(tokens) <= wc.count {
			// Within the limit: nothing to evict.
			continue
		}
		for _, stale := range tokens[wc.count:] {
			if err := m.delAuthToken(ctx, stale.AccessToken); err != nil {
				xlog.Warnf(ctx, "%s delauthtoken err: tk: %s, err: %v", fun, stale.AccessToken, err)
			}
		}
	}
}

// getUserTokenCounts returns the per-user token-count limit for appid/authtype,
// and whether a positive limit is configured.
func getUserTokenCounts(ctx context.Context, appid int, authtype web.EnumAuthType) (int, bool) {
	const perDevice = false
	limit, ok := getTokenCountLimit(ctx, appid, authtype, perDevice)
	return limit, ok
}
// getUserDidTokenCounts returns the per-device ("did") token-count limit for
// appid/authtype, and whether a positive limit is configured.
func getUserDidTokenCounts(ctx context.Context, appid int, authtype web.EnumAuthType) (int, bool) {
	const perDevice = true
	limit, ok := getTokenCountLimit(ctx, appid, authtype, perDevice)
	return limit, ok
}

// tokenLimitConfig maps limit keys to the maximum number of tokens to keep.
// The base key is "token.counts", or "token.did.counts" for the per-device
// dimension. It can be suffixed with ".<appid>.<authtype>" for a specific
// app/auth-type pair.
var tokenLimitConfig = map[string]int{
	"token.counts":     2,
	"token.counts.1.1": 1,
}

// getTokenCountLimit looks up the token-count limit in tokenLimitConfig.
//
// The base key is "token.counts", or "token.did.counts" when isDid is set.
// When both appid and authtype are non-zero, ".<appid>.<authtype>" is
// appended. It returns (limit, true) when a positive limit is configured for
// that key, else (0, false). ctx is unused.
//
// NOTE(review): for a non-zero appid/authtype only the specific key is
// consulted. There is no fallback to the base key, so such pairs are unlimited
// unless explicitly configured. Confirm this is intended.
func getTokenCountLimit(ctx context.Context, appid int, authtype web.EnumAuthType, isDid bool) (int, bool) {
	key := "token.counts"
	if isDid {
		key = "token.did.counts"
	}
	if appid != 0 && authtype != 0 {
		key = fmt.Sprintf("%s.%d.%d", key, appid, authtype)
	}
	c, ok := tokenLimitConfig[key]
	if !ok || c <= 0 {
		return 0, false
	}
	return c, true
}
